/*******************************************************************************
 *  Project: TuYuOCR
 *  Purpose: OpenSource OCR Engine
 *  Author: TuYuOCR contributors
 *******************************************************************************
 *  The MIT License (MIT)
 *
 *  Copyright (c) 2019 TuYuOCR contributors
 *
 *  Permission is hereby granted, free of charge, to any person obtaining a copy
 *  of this software and associated documentation files (the "Software"), to
 *  deal in the Software without restriction, including without limitation the
 *  rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 *  sell copies of the Software, and to permit persons to whom the Software is
 *  furnished to do so, subject to the following conditions:
 *
 *  The above copyright notice and this permission notice shall be included in
 *  all copies or substantial portions of the Software.
 *
 *  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 *  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 *  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 *  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 *  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 *  FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 *  IN THE SOFTWARE.
 *******************************************************************************/

#include "crnn/crnn.hpp"

#include <cassert>
#include <cstdio>
#include <cstdlib>

#include <algorithm>
#include <chrono>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"
#include "spdlog/spdlog.h"

// Evaluates `expr` (an expression yielding an OrtStatus* from the ONNX Runtime
// C API); on failure prints the error message to stderr, releases the status
// and aborts the process.
//
// The C-API entry points are reached through Ort::GetApi(), the dispatch table
// used by the Ort:: C++ wrappers this file relies on. The do/while(0) has NO
// trailing semicolon so a use reads as a single statement, e.g. inside an
// unbraced if/else.
#define ORT_ABORT_ON_ERROR(expr)                                     \
  do {                                                               \
    OrtStatus* onnx_status = (expr);                                 \
    if (onnx_status != nullptr) {                                    \
      const char* msg = Ort::GetApi().GetErrorMessage(onnx_status);  \
      std::fprintf(stderr, "%s\n", msg);                             \
      Ort::GetApi().ReleaseStatus(onnx_status);                      \
      std::abort();                                                  \
    }                                                                \
  } while (0)

namespace tuyu {

/// Builds a recognizer: loads the ONNX model at `model_name` and the alphabet
/// JSON file `json_filename` (see InitModel).
CRNN::CRNN(const std::string& model_name, const std::string& json_filename) {
  InitModel(model_name, json_filename);
}

CRNN::~CRNN() {}

/// Creates the ONNX Runtime environment and session for `model_path` and
/// parses `json_filename` into `json_alphabet`.
///
/// @throws std::runtime_error if the alphabet file cannot be opened. Errors
///         raised by the Ort wrappers or the JSON parser propagate unchanged.
void CRNN::InitModel(const std::string& model_path,
                     const std::string& json_filename) {
  env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "Default");
  session_ =
      Ort::Session(env_, model_path.c_str(), Ort::SessionOptions(nullptr));

  // Check the open explicitly: otherwise a missing file surfaces as an
  // unhelpful parse error on an empty stream.
  std::ifstream ifs(json_filename);
  if (!ifs.is_open()) {
    throw std::runtime_error("CRNN: cannot open alphabet file: " +
                             json_filename);
  }
  ifs >> json_alphabet;
}

/// Recognizes the text in `image` and returns it as a string.
///
/// Pipeline: PreprocessImage -> session run (input "image", output "output")
/// -> per-time-step argmax over classes -> GreedyDecode -> alphabet lookup,
/// where label k (k > 0) maps to json_alphabet["alphabet"][k - 1].
///
/// @throws std::runtime_error if the preprocessed image is not single-channel
///         or the model output is not a 3-D tensor with batch size 1.
///         The JSON library's at() throws if the "alphabet" key is missing or a
///         label lies past the end of the alphabet.
std::string CRNN::Predict(const cv::Mat& image) {
  cv::Mat out;
  PreprocessImage(image, out);
  if (out.channels() != 1) {
    throw std::runtime_error("CRNN::Predict: expected a single-channel input");
  }
  // ONNX Runtime wraps this buffer without copying, so it must be one
  // contiguous block of floats.
  if (!out.isContinuous()) {
    out = out.clone();
  }

  const int64_t image_width = out.cols;
  const int64_t image_height = out.rows;
  const size_t input_element_count =
      static_cast<size_t>(image_width * image_height);

  auto memory_info =
      Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

  // Shape {1, 1, H, W} (presumably batch, channel, height, width).
  const int64_t input_shape[] = {1, 1, image_height, image_width};
  const size_t input_shape_len = sizeof(input_shape) / sizeof(input_shape[0]);

  // The typed CreateTensor<T> takes an ELEMENT count; the byte count belongs
  // to the untyped overload, so do not multiply by sizeof(float).
  Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
      memory_info, out.ptr<float>(), input_element_count, input_shape,
      input_shape_len);
  assert(input_tensor.IsTensor());

  const char* input_names[] = {"image"};
  const char* output_names[] = {"output"};
  auto output_tensors = session_.Run(Ort::RunOptions{nullptr}, input_names,
                                     &input_tensor, 1, output_names, 1);
  assert(output_tensors.size() == 1 && output_tensors.front().IsTensor());

  auto& output_tensor = output_tensors.front();
  const float* out_array = output_tensor.GetTensorMutableData<float>();

  auto type_info = output_tensor.GetTypeInfo();
  auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
  const size_t out_num_dims = tensor_info.GetDimensionsCount();
  if (out_num_dims != 3) {
    throw std::runtime_error("CRNN::Predict: expected a 3-D output, got " +
                             std::to_string(out_num_dims) + " dims");
  }
  std::vector<int64_t> output_node_dims(out_num_dims);
  tensor_info.GetDimensions(output_node_dims.data(), out_num_dims);

  // Dims are read as (T, N, C): time steps, batch, classes. The flat indexing
  // below (t * C + c) is only valid for a batch of one.
  const int64_t num_steps = output_node_dims[0];
  const int64_t batch = output_node_dims[1];
  const int64_t num_classes = output_node_dims[2];
  if (batch != 1) {
    throw std::runtime_error("CRNN::Predict: expected batch size 1, got " +
                             std::to_string(batch));
  }

  // For every time step keep the highest-scoring class. max_element returns
  // the first maximum, so ties resolve to the lowest class index.
  std::vector<int> preds;
  preds.reserve(static_cast<size_t>(num_steps));
  for (int64_t t = 0; t < num_steps; ++t) {
    const float* row = out_array + t * num_classes;
    const int idx =
        static_cast<int>(std::max_element(row, row + num_classes) - row);
    SPDLOG_INFO("idx is {}", idx);
    preds.emplace_back(idx);
  }

  const std::vector<int> result = GreedyDecode(preds);

  // Labels index the alphabet 1-based (label k -> alphabet[k - 1]), so label 0
  // has no character (presumably the CTC blank) and is skipped. at() is used
  // because, assuming json_alphabet is an nlohmann::json, operator[] on the
  // non-const member would silently grow the stored alphabet for
  // out-of-range labels.
  const auto& alphabet = json_alphabet.at("alphabet");
  std::string ret_result;
  for (const int idx : result) {
    if (idx <= 0) {
      continue;
    }
    ret_result += alphabet.at(static_cast<size_t>(idx - 1)).get<std::string>();
  }
  SPDLOG_INFO("ret result is {}", ret_result);
  return ret_result;
}

/// Converts `image` into the float tensor layout the model consumes and
/// writes it to `out` (single channel, CV_32F, height 32).
///
/// @throws std::invalid_argument if `image` is empty.
void CRNN::PreprocessImage(const cv::Mat& image, cv::Mat& out) {
  if (image.empty()) {
    throw std::invalid_argument("CRNN::PreprocessImage: empty input image");
  }

  // The network takes one channel: collapse colour input to grayscale.
  cv::Mat gray;
  if (image.channels() == 3) {
    cv::cvtColor(image, gray, cv::COLOR_BGR2GRAY);
  } else if (image.channels() == 4) {
    cv::cvtColor(image, gray, cv::COLOR_BGRA2GRAY);
  } else {
    gray = image;
  }

  // Scale to a fixed height of param_h pixels. The horizontal factor 0.571f
  // is a hard-coded constant. NOTE(review): its origin is not visible here;
  // confirm it matches the preprocessing the model was trained with.
  const int param_h = 32;
  const float ratio = static_cast<float>(param_h) / static_cast<float>(image.rows);
  cv::Mat resized;
  cv::resize(gray, resized, cv::Size(0, 0), 0.571f, ratio);

  // Map 8-bit pixels to [0, 1], then normalize with mean 0.5 / std 0.5 to
  // [-1, 1]. The normalization must be applied to `out`, the buffer that is
  // actually handed to the model.
  const double scale = (resized.depth() == CV_8U) ? 1.0 / 255.0 : 1.0;
  resized.convertTo(out, CV_32F, scale);
  out = (out - 0.5) / 0.5;
}

}  // namespace tuyu